#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <time.h>
#include <math.h>
#include <openssl/evp.h>

#include "rng.h"
#include "api.h"
#include "gmp.h"
#include "kaz_api.h"

/*
 * Draw a uniformly distributed integer `out` with lb <= out <= ub
 * (both bounds inclusive).
 *
 * Randomness is taken from randombytes() rather than a GMP generator:
 * GMP's generators are not cryptographically secure, and seeding one with
 * a constant makes every draw identical across calls and across runs.
 * NOTE(review): randombytes() is expected to be declared by rng.h (NIST PQC
 * DRBG interface, returning 0 on success) -- confirm against rng.h.
 *
 * Uniformity is obtained by rejection sampling. Candidates have exactly
 * bitlen(ub - lb + 1) bits and are redrawn until they fall below the range
 * size, so fewer than two draws are needed on average.
 *
 * The process is aborted if ub < lb, if the scratch buffer cannot be
 * allocated, or if randombytes() fails. Continuing with a non-random value
 * would silently compromise any key derived from it.
 *
 * lb, ub: inclusive bounds (not modified).
 * out:    receives the sampled value (may alias lb).
 */
void KAZ_KEM_RANDOM(mpz_t lb, mpz_t ub, mpz_t out)
{
	mpz_t range, candidate;

	mpz_inits(range, candidate, NULL);

	// range = ub - lb + 1 = number of admissible values
	mpz_sub(range, ub, lb);
	mpz_add_ui(range, range, 1);

	if (mpz_sgn(range) <= 0) {
		fprintf(stderr, "KAZ-KEM-RANDOM: empty range (ub < lb).\n");
		abort();
	}

	size_t nbits = mpz_sizeinbase(range, 2);
	size_t nbytes = (nbits + 7) / 8;
	// Clears the unused high bits of the most significant byte, so candidates lie in [0, 2^nbits).
	unsigned char topmask = (unsigned char)(0xFFu >> (nbytes * 8 - nbits));

	unsigned char *buf = malloc(nbytes);
	if (buf == NULL) {
		fprintf(stderr, "KAZ-KEM-RANDOM: Memory allocation failed.\n");
		abort();
	}

	do {
		if (randombytes(buf, nbytes) != 0) {
			fprintf(stderr, "KAZ-KEM-RANDOM: randombytes failed.\n");
			abort();
		}
		buf[0] &= topmask;
		mpz_import(candidate, nbytes, 1, sizeof(unsigned char), 0, 0, buf);
	} while (mpz_cmp(candidate, range) >= 0);

	// out = lb + candidate, with candidate in [0, ub - lb]
	mpz_add(out, lb, candidate);

	// The buffer held secret material; wipe it through a volatile pointer so the stores are not elided.
	volatile unsigned char *wipe = buf;
	for (size_t i = 0; i < nbytes; i++) {
		wipe[i] = 0;
	}
	free(buf);

	mpz_clears(range, candidate, NULL);
}

int KAZ_KEM_KEYGEN(unsigned char *kaz_kem_public_key, unsigned char *kaz_kem_private_key)
{
	int ret=0;
    mpz_t N, g1, g2, Og1N, Og2N, Og3N, a1, a2, e1, e2;
    mpz_t tmp, lowerbound, upperbound;

    mpz_inits(N, g1, g2, Og1N, Og2N, Og3N, a1, a2, e1, e2, NULL);
    mpz_inits(tmp, lowerbound, upperbound, NULL);

    //Get all system parameters and precomputed parameters
    mpz_set_str(N, KAZ_KEM_SP_N, 10);
    mpz_set_str(g1, KAZ_KEM_SP_g1, 10);
    mpz_set_str(g2, KAZ_KEM_SP_g2, 10);
	mpz_set_str(Og1N, KAZ_KEM_SP_Og1N, 10);
	mpz_set_str(Og2N, KAZ_KEM_SP_Og2N, 10);
	mpz_set_str(Og3N, KAZ_KEM_SP_Og3N, 10);
	
	// Generate a1, a2 randomly
	mpz_ui_pow_ui(lowerbound, 2, KAZ_KEM_SP_LOg1N-2);
	mpz_set(upperbound, Og1N);
	KAZ_KEM_RANDOM(lowerbound, upperbound, a1);
	
	mpz_ui_pow_ui(lowerbound, 2, KAZ_KEM_SP_LOg2N-2);
	mpz_set(upperbound, Og2N);
	KAZ_KEM_RANDOM(lowerbound, upperbound, a2);
	
	// Compute e1
	mpz_powm(e1, g1, a1, N);
	mpz_mul_ui(tmp, a2, 2);
	mpz_powm(tmp, g2, tmp, N);
	mpz_mul(e1, e1, tmp);
	mpz_mod(e1, e1, N);
	
	// Compute e2
	mpz_powm(e2, g1, a2, N);
	mpz_powm(tmp, g2, a1, N);
	mpz_mul(e2, e2, tmp);
	mpz_mod(e2, e2, N);
	
    // Set kaz_kem_public_key={e1, e2} & kaz_kem_private_key=(a1, a2)
    size_t E1SIZE=mpz_sizeinbase(e1, 16);
	size_t E2SIZE=mpz_sizeinbase(e2, 16);
	size_t a1SIZE=mpz_sizeinbase(a1, 16);
	size_t a2SIZE=mpz_sizeinbase(a2, 16);

	unsigned char *E1BYTE=NULL;
	unsigned char *E2BYTE=NULL;
	unsigned char *a1BYTE=NULL;
	unsigned char *a2BYTE=NULL;

	E1BYTE=(unsigned char*) malloc(E1SIZE*sizeof(unsigned char));
	E2BYTE=(unsigned char*) malloc(E2SIZE*sizeof(unsigned char));
	a1BYTE=(unsigned char*) malloc(a1SIZE*sizeof(unsigned char));
	a2BYTE=(unsigned char*) malloc(a2SIZE*sizeof(unsigned char));

	if (!E1BYTE || !E2BYTE || !a1BYTE || !a2BYTE) {
        fprintf(stderr, "KAZ-KEM-KEYGEN: Memory allocation failed.\n");
		ret=-4;
        goto kaz_kem_cleanup;
    }
	
	mpz_export(E1BYTE, &E1SIZE, 1, sizeof(char), 0, 0, e1);
	mpz_export(E2BYTE, &E2SIZE, 1, sizeof(char), 0, 0, e2);
	mpz_export(a1BYTE, &a1SIZE, 1, sizeof(char), 0, 0, a1);
	mpz_export(a2BYTE, &a2SIZE, 1, sizeof(char), 0, 0, a2);

	memset(kaz_kem_public_key, 0, KAZ_KEM_PUBLICKEY_BYTES*2);
	memset(kaz_kem_private_key, 0, KAZ_KEM_PRIVATEKEY_BYTES*2);

	//for(int i=0; i<KAZ_KEM_PUBLICKEY_BYTES*2; i++) kaz_kem_public_key[i]=0;
	//for(int i=0; i<KAZ_KEM_PRIVATEKEY_BYTES*2; i++) kaz_kem_private_key[i]=0;

	int je=(KAZ_KEM_PUBLICKEY_BYTES*2)-1;
	
	for(int i=E2SIZE-1; i>=0; i--){
		kaz_kem_public_key[je]=E2BYTE[i];
		je--;
	}

	je=(KAZ_KEM_PUBLICKEY_BYTES*2)-KAZ_KEM_PUBLICKEY_BYTES-1;
	for(int i=E1SIZE-1; i>=0; i--){
		kaz_kem_public_key[je]=E1BYTE[i];
		je--;
	}

	je=(KAZ_KEM_PRIVATEKEY_BYTES*2)-1;
	for(int i=a2SIZE-1; i>=0; i--){
		kaz_kem_private_key[je]=a2BYTE[i];
		je--;
	}

	je=(KAZ_KEM_PRIVATEKEY_BYTES*2)-KAZ_KEM_PRIVATEKEY_BYTES-1;
	for(int i=a1SIZE-1; i>=0; i--){
		kaz_kem_private_key[je]=a1BYTE[i];
		je--;
	}

	kaz_kem_cleanup:
		mpz_clears(N, g1, g2, Og1N, Og2N, a1, a2, e1, e2, NULL);
		mpz_clears(tmp, lowerbound, upperbound, NULL);
		free(E1BYTE);
		free(E2BYTE);
		free(a1BYTE);
		free(a2BYTE);

	return ret;
}

int KAZ_KEM_ENCAPSULATION(unsigned char *encap, unsigned long long *encaplen, 
						  const unsigned char *m, unsigned long long mlen, 
						  const unsigned char *pk)
{
	int ret=0;
    mpz_t N, g1, g2, Og1N, Og2N, e1, e2, b1, b2, B1, B2, M, ENCAP;
    mpz_t tmp, lowerbound, upperbound;

    mpz_inits(N, g1, g2, Og1N, Og2N, e1, e2, b1, b2, B1, B2, M, ENCAP, NULL);
    mpz_inits(tmp, lowerbound, upperbound, NULL);

    //Get all system parameters and precomputed parameters
    mpz_set_str(N, KAZ_KEM_SP_N, 10);
    mpz_set_str(g1, KAZ_KEM_SP_g1, 10);
    mpz_set_str(g2, KAZ_KEM_SP_g2, 10);
	mpz_set_str(Og1N, KAZ_KEM_SP_Og1N, 10);
	mpz_set_str(Og2N, KAZ_KEM_SP_Og2N, 10);

	// Get kaz_kem_public_key={A1, A2, A3} 
	unsigned char *E1BYTE=NULL;
	unsigned char *E2BYTE=NULL;

	E1BYTE=(unsigned char*) malloc((KAZ_KEM_PUBLICKEY_BYTES)*sizeof(unsigned char));
	E2BYTE=(unsigned char*) malloc((KAZ_KEM_PUBLICKEY_BYTES)*sizeof(unsigned char));
	
	//for(int i=0; i<KAZ_KEM_PUBLICKEY_BYTES; i++) E1BYTE[i]=0;
	//for(int i=0; i<KAZ_KEM_PUBLICKEY_BYTES; i++) E2BYTE[i]=0;

	memset(E1BYTE, 0, KAZ_KEM_PUBLICKEY_BYTES);
	memset(E2BYTE, 0, KAZ_KEM_PUBLICKEY_BYTES);

	for(int i=0; i<KAZ_KEM_PUBLICKEY_BYTES; i++) E1BYTE[i]=pk[i];
	for(int i=0; i<KAZ_KEM_PUBLICKEY_BYTES; i++) E2BYTE[i]=pk[i+KAZ_KEM_PUBLICKEY_BYTES];
	
	mpz_import(e1, KAZ_KEM_PUBLICKEY_BYTES, 1, sizeof(char), 0, 0, E1BYTE);
	mpz_import(e2, KAZ_KEM_PUBLICKEY_BYTES, 1, sizeof(char), 0, 0, E2BYTE);
	
	// Generate b1, b2 randomly
	mpz_ui_pow_ui(lowerbound, 2, KAZ_KEM_SP_LOg1N-2);
	mpz_set(upperbound, Og1N);
	KAZ_KEM_RANDOM(lowerbound, upperbound, b1);
	
	mpz_ui_pow_ui(lowerbound, 2, KAZ_KEM_SP_LOg2N-2);
	mpz_set(upperbound, Og2N);
	KAZ_KEM_RANDOM(lowerbound, upperbound, b2);
	
	// Compute B1
	mpz_powm(B1, g1, b1, N);
	mpz_powm(tmp, g2, b2, N);
	mpz_mul(B1, B1, tmp);
	mpz_mod(B1, B1, N);
	
	// Compute B2
	mpz_powm(B2, g1, b2, N);
	mpz_mul_ui(tmp, b1, 2);
	mpz_powm(tmp, g2, tmp, N);
	mpz_mul(B2, B2, tmp);
	mpz_mod(B2, B2, N);
	
	// Compute Encapsulation
	mpz_import(M, KAZ_KEM_GENERAL_BYTES, 1, sizeof(char), 0, 0, m);
	
	mpz_powm(ENCAP, e1, b1, N);
	mpz_powm(tmp, e2, b2, N);
	mpz_mul(ENCAP, ENCAP, tmp);
	mpz_mod(ENCAP, ENCAP, N);
	mpz_add(ENCAP, ENCAP, M);
	mpz_mod(ENCAP, ENCAP, N);

	// Set kaz_kem_encapsulation={ENCAP, B1, B2}, kaz_kem_ephemeral_public={B1, B2} & kaz_kem_ephemeral_private=(b1, b2)
	
	size_t ENCAPSIZE=mpz_sizeinbase(ENCAP, 16);
	size_t B1SIZE=mpz_sizeinbase(B1, 16);
	size_t B2SIZE=mpz_sizeinbase(B2, 16);

	unsigned char *ENCAPBYTE=NULL;
	unsigned char *B1BYTE=NULL;
	unsigned char *B2BYTE=NULL;
	
	ENCAPBYTE=(unsigned char*) malloc(ENCAPSIZE*sizeof(unsigned char));
	B1BYTE=(unsigned char*) malloc(B1SIZE*sizeof(unsigned char));
	B2BYTE=(unsigned char*) malloc(B2SIZE*sizeof(unsigned char));
	
	if (!ENCAPBYTE || !B1BYTE || !B2BYTE) {
        fprintf(stderr, "KAZ-KEM-ENCAPSULATION: Memory allocation failed.\n");
		ret=-4;
        goto kaz_kem_cleanup;
    }
	
	mpz_export(ENCAPBYTE, &ENCAPSIZE, 1, sizeof(char), 0, 0, ENCAP);
	mpz_export(B1BYTE, &B1SIZE, 1, sizeof(char), 0, 0, B1);
	mpz_export(B2BYTE, &B2SIZE, 1, sizeof(char), 0, 0, B2);

	memset(encap, 0, KAZ_KEM_GENERAL_BYTES+(KAZ_KEM_EPHERMERAL_PUBLIC_BYTES*2));

	//for(int i=0; i<KAZ_KEM_GENERAL_BYTES+(KAZ_KEM_EPHERMERAL_PUBLIC_BYTES*2); i++) encap[i]=0;

	int je=(KAZ_KEM_GENERAL_BYTES+(KAZ_KEM_EPHERMERAL_PUBLIC_BYTES*2))-1;
	
	for(int i=B2SIZE-1; i>=0; i--){
		encap[je]=B2BYTE[i];
		je--;
	}

	je=(KAZ_KEM_GENERAL_BYTES+(KAZ_KEM_EPHERMERAL_PUBLIC_BYTES*2))-KAZ_KEM_EPHERMERAL_PUBLIC_BYTES-1;
	for(int i=B1SIZE-1; i>=0; i--){
		encap[je]=B1BYTE[i];
		je--;
	}
	
	je=(KAZ_KEM_GENERAL_BYTES+(KAZ_KEM_EPHERMERAL_PUBLIC_BYTES*2))-KAZ_KEM_EPHERMERAL_PUBLIC_BYTES-KAZ_KEM_EPHERMERAL_PUBLIC_BYTES-1;
	for(int i=ENCAPSIZE-1; i>=0; i--){
		encap[je]=ENCAPBYTE[i];
		je--;
	}
	
	*encaplen=KAZ_KEM_GENERAL_BYTES+(KAZ_KEM_EPHERMERAL_PUBLIC_BYTES*2);

	kaz_kem_cleanup:
		mpz_clears(N, g1, g2, Og1N, Og2N, e1, e2, b1, b2, B1, B2, M, ENCAP, NULL);
		mpz_clears(tmp, lowerbound, upperbound, NULL);
		free(ENCAPBYTE);
		free(B1BYTE);
		free(B2BYTE);

	return ret;
}

/*
 * Recover the encapsulated message from encap using the private key sk.
 *
 * encap is ENCAP || B1 || B2, the layout written by KAZ_KEM_ENCAPSULATION:
 *     ENCAP at offset 0 (KAZ_KEM_GENERAL_BYTES),
 *     B1 at offset KAZ_KEM_GENERAL_BYTES,
 *     B2 at offset KAZ_KEM_GENERAL_BYTES + KAZ_KEM_EPHERMERAL_PUBLIC_BYTES.
 * sk is a1 || a2, each a big-endian KAZ_KEM_PRIVATEKEY_BYTES field.
 *
 * Computes DECAP = (ENCAP - B1^a1 * B2^a2) mod N. The result is written to
 * decap as exactly KAZ_KEM_GENERAL_BYTES bytes, big-endian and zero-padded
 * on the left, so that leading zero bytes of the message are preserved.
 * On success *decaplen is set to KAZ_KEM_GENERAL_BYTES.
 *
 * Returns 0 on success. Returns -1 if encaplen is shorter than the
 * encapsulation layout, or if DECAP does not fit in KAZ_KEM_GENERAL_BYTES
 * (decap is then left zeroed).
 */
int KAZ_KEM_DECAPSULATION(unsigned char *decap, unsigned long long *decaplen, 
						  const unsigned char *encap, unsigned long long encaplen, 
						  const unsigned char *sk)
{
	const size_t sk_field = KAZ_KEM_PRIVATEKEY_BYTES;
	const size_t ct_field = KAZ_KEM_GENERAL_BYTES;
	const size_t eph_field = KAZ_KEM_EPHERMERAL_PUBLIC_BYTES;

	// All three fields are read unconditionally; refuse a shorter input rather than over-read it.
	if (encaplen < (unsigned long long)(ct_field + 2 * eph_field)) {
		fprintf(stderr, "KAZ-KEM-DECAPSULATION: encapsulation too short.\n");
		return -1;
	}

	int ret = 0;
	mpz_t N, B1, B2, a1, a2, ENCAP, DECAP;
	mpz_t tmp;

	mpz_inits(N, B1, B2, a1, a2, ENCAP, DECAP, NULL);
	mpz_inits(tmp, NULL);

	// System parameter, supplied as a decimal-string macro.
	mpz_set_str(N, KAZ_KEM_SP_N, 10);

	// Encapsulation = ENCAP || B1 || B2 (offsets match the encapsulation writer).
	mpz_import(ENCAP, ct_field, 1, sizeof(unsigned char), 0, 0, encap);
	mpz_import(B1, eph_field, 1, sizeof(unsigned char), 0, 0, encap + ct_field);
	mpz_import(B2, eph_field, 1, sizeof(unsigned char), 0, 0, encap + ct_field + eph_field);

	// Private key = a1 || a2; imported directly so no heap copies of it exist.
	mpz_import(a1, sk_field, 1, sizeof(unsigned char), 0, 0, sk);
	mpz_import(a2, sk_field, 1, sizeof(unsigned char), 0, 0, sk + sk_field);

	// DECAP = (ENCAP - B1^a1 * B2^a2) mod N; mpz_mod yields a non-negative result.
	mpz_powm(DECAP, B1, a1, N);
	mpz_powm(tmp, B2, a2, N);
	mpz_mul(DECAP, DECAP, tmp);
	mpz_mod(DECAP, DECAP, N);
	mpz_sub(DECAP, ENCAP, DECAP);
	mpz_mod(DECAP, DECAP, N);

	// Emit exactly ct_field bytes, right-aligned, so a message with leading zero bytes round-trips.
	memset(decap, 0, ct_field);
	size_t nbytes = (mpz_sizeinbase(DECAP, 2) + 7) / 8;
	if (nbytes > ct_field) {
		fprintf(stderr, "KAZ-KEM-DECAPSULATION: result does not fit KAZ_KEM_GENERAL_BYTES.\n");
		ret = -1;
	} else {
		// A zero result exports no bytes; the memset above already provides its encoding.
		mpz_export(decap + ct_field - nbytes, NULL, 1, sizeof(unsigned char), 0, 0, DECAP);
		*decaplen = ct_field;
	}

	mpz_clears(N, B1, B2, a1, a2, ENCAP, DECAP, NULL);
	mpz_clears(tmp, NULL);

	return ret;
}